// compio_driver/sys/iour/mod.rs
1#[cfg_attr(all(doc, docsrs), doc(cfg(all())))]
2#[allow(unused_imports)]
3pub use std::os::fd::{AsFd, AsRawFd, BorrowedFd, OwnedFd, RawFd};
4use std::{
5    io,
6    os::fd::FromRawFd,
7    pin::Pin,
8    sync::Arc,
9    task::{Poll, Wake, Waker},
10    time::Duration,
11};
12
13use compio_log::{instrument, trace, warn};
14use crossbeam_queue::SegQueue;
15cfg_if::cfg_if! {
16    if #[cfg(feature = "io-uring-cqe32")] {
17        use io_uring::cqueue::Entry32 as CEntry;
18    } else {
19        use io_uring::cqueue::Entry as CEntry;
20    }
21}
22cfg_if::cfg_if! {
23    if #[cfg(feature = "io-uring-sqe128")] {
24        use io_uring::squeue::Entry128 as SEntry;
25    } else {
26        use io_uring::squeue::Entry as SEntry;
27    }
28}
29use io_uring::{
30    IoUring,
31    cqueue::more,
32    opcode::{AsyncCancel, PollAdd},
33    types::{Fd, SubmitArgs, Timespec},
34};
35use slab::Slab;
36
37use crate::{
38    AsyncifyPool, BufferPool, DriverType, Entry, ProactorBuilder,
39    key::{ErasedKey, Key, RefExt},
40    syscall,
41};
42
43mod extra;
44pub use extra::Extra;
45pub(crate) mod op;
46
/// Return whether the kernel supports the io-uring opcode `code`.
///
/// The result of a one-time `register_probe` on a throwaway 2-entry ring is
/// cached in a process-wide `OnceLock`; if probing fails (e.g. io-uring is
/// unavailable), `unwrap_or_default()` reports `false` for every opcode.
pub(crate) fn is_op_supported(code: u8) -> bool {
    // `OnceLock::get_or_try_init` is gated behind the `once_cell_try` nightly
    // feature; fall back to the `once_cell` crate's equivalent otherwise.
    #[cfg(feature = "once_cell_try")]
    use std::sync::OnceLock;

    #[cfg(not(feature = "once_cell_try"))]
    use once_cell::sync::OnceCell as OnceLock;

    static PROBE: OnceLock<io_uring::Probe> = OnceLock::new();

    PROBE
        .get_or_try_init(|| {
            let mut probe = io_uring::Probe::new();

            // A minimal ring is enough to run the probe registration.
            io_uring::IoUring::new(2)?
                .submitter()
                .register_probe(&mut probe)?;

            std::io::Result::Ok(probe)
        })
        .map(|probe| probe.is_supported(code))
        .unwrap_or_default()
}
69
/// The created entry of [`OpCode`].
pub enum OpEntry {
    /// This operation creates an io-uring submission entry.
    Submission(io_uring::squeue::Entry),
    #[cfg(feature = "io-uring-sqe128")]
    /// This operation creates an 128-bit io-uring submission entry.
    Submission128(io_uring::squeue::Entry128),
    /// This operation is a blocking one and should be dispatched to the
    /// thread pool instead of the ring.
    Blocking,
}
80
81impl OpEntry {
82    fn personality(self, personality: Option<u16>) -> Self {
83        let Some(personality) = personality else {
84            return self;
85        };
86
87        match self {
88            Self::Submission(entry) => Self::Submission(entry.personality(personality)),
89            #[cfg(feature = "io-uring-sqe128")]
90            Self::Submission128(entry) => Self::Submission128(entry.personality(personality)),
91            Self::Blocking => Self::Blocking,
92        }
93    }
94}
95
// Allow plain submission entries to be returned directly from `create_entry`.
impl From<io_uring::squeue::Entry> for OpEntry {
    fn from(value: io_uring::squeue::Entry) -> Self {
        Self::Submission(value)
    }
}
101
// Same convenience conversion for 128-byte SQEs when the feature is enabled.
#[cfg(feature = "io-uring-sqe128")]
impl From<io_uring::squeue::Entry128> for OpEntry {
    fn from(value: io_uring::squeue::Entry128) -> Self {
        Self::Submission128(value)
    }
}
108
/// Abstraction of io-uring operations.
pub trait OpCode {
    /// Create submission entry.
    ///
    /// Returning [`OpEntry::Blocking`] routes the operation to the thread
    /// pool; any submission variant is pushed onto the ring.
    fn create_entry(self: Pin<&mut Self>) -> OpEntry;

    /// Call the operation in a blocking way. This method will only be called if
    /// [`create_entry`] returns [`OpEntry::Blocking`].
    ///
    /// [`create_entry`]: OpCode::create_entry
    fn call_blocking(self: Pin<&mut Self>) -> io::Result<usize> {
        unreachable!("this operation is asynchronous")
    }

    /// Set the result when it successfully completes.
    /// The operation stores the result and is responsible to release it if the
    /// operation is cancelled.
    ///
    /// # Safety
    ///
    /// Users should not call it.
    unsafe fn set_result(self: Pin<&mut Self>, _: usize) {}
}
131
132pub use OpCode as IourOpCode;
133
134/// Low-level driver of io-uring.
135pub(crate) struct Driver {
136    inner: IoUring<SEntry, CEntry>,
137    notifier: Notifier,
138    pool: AsyncifyPool,
139    pool_completed: Arc<SegQueue<Entry>>,
140    buffer_group_ids: Slab<()>,
141    need_push_notifier: bool,
142}
143
impl Driver {
    // Sentinel `user_data` for `AsyncCancel` entries; their completions are
    // ignored in `poll_entries`.
    const CANCEL: u64 = u64::MAX;
    // Sentinel `user_data` for the notifier's multishot `PollAdd` entry.
    const NOTIFY: u64 = u64::MAX - 1;

    /// Create a driver configured from `builder`: sqpoll idle time,
    /// cooperative taskrun flags, ring capacity and an optional eventfd to
    /// register for completion notification.
    pub fn new(builder: &ProactorBuilder) -> io::Result<Self> {
        instrument!(compio_log::Level::TRACE, "new", ?builder);
        trace!("new iour driver");
        let notifier = Notifier::new()?;
        let mut io_uring_builder = IoUring::builder();
        if let Some(sqpoll_idle) = builder.sqpoll_idle {
            io_uring_builder.setup_sqpoll(sqpoll_idle.as_millis() as _);
        }
        if builder.coop_taskrun {
            io_uring_builder.setup_coop_taskrun();
        }
        if builder.taskrun_flag {
            io_uring_builder.setup_taskrun_flag();
        }

        let inner = io_uring_builder.build(builder.capacity)?;

        let submitter = inner.submitter();

        if let Some(fd) = builder.eventfd {
            submitter.register_eventfd(fd)?;
        }

        Ok(Self {
            inner,
            notifier,
            pool: builder.create_or_get_thread_pool(),
            pool_completed: Arc::new(SegQueue::new()),
            buffer_group_ids: Slab::new(),
            // The notifier's PollAdd is pushed lazily on the first `poll`.
            need_push_notifier: true,
        })
    }

    /// This backend always reports [`DriverType::IoUring`].
    pub fn driver_type(&self) -> DriverType {
        DriverType::IoUring
    }

    // NOTE(review): presumably a downcast hook used by the fused driver build
    // (`cfg(fusion)` elsewhere in this file) — confirm against the caller.
    #[allow(dead_code)]
    pub fn as_iour(&self) -> Option<&Self> {
        Some(self)
    }

    /// Register a new io-uring personality and return its id.
    pub fn register_personality(&self) -> io::Result<u16> {
        self.inner.submitter().register_personality()
    }

    /// Unregister a personality previously returned by
    /// [`register_personality`](Self::register_personality).
    pub fn unregister_personality(&self, personality: u16) -> io::Result<()> {
        self.inner.submitter().unregister_personality(personality)
    }

    // "Auto" means it chooses whether to wait based on the ring's state.
    //
    // Returns `TimedOut` when the wait elapsed with no completions, and maps
    // EBUSY/EAGAIN to `Interrupted` so callers retry instead of failing.
    fn submit_auto(&mut self, timeout: Option<Duration>) -> io::Result<()> {
        instrument!(compio_log::Level::TRACE, "submit_auto", ?timeout);

        // when taskrun is true, there are completed cqes wait to handle, no need to
        // block the submit
        let want_sqe = if self.inner.submission().taskrun() {
            0
        } else {
            1
        };

        let res = {
            // Last part of submission queue, wait till timeout.
            if let Some(duration) = timeout {
                let timespec = timespec(duration);
                let args = SubmitArgs::new().timespec(&timespec);
                self.inner.submitter().submit_with_args(want_sqe, &args)
            } else {
                self.inner.submit_and_wait(want_sqe)
            }
        };
        trace!("submit result: {res:?}");
        match res {
            Ok(_) => {
                // A successful submit with an empty completion queue means the
                // wait ran out before anything completed.
                if self.inner.completion().is_empty() {
                    Err(io::ErrorKind::TimedOut.into())
                } else {
                    Ok(())
                }
            }
            Err(e) => match e.raw_os_error() {
                Some(libc::ETIME) => Err(io::ErrorKind::TimedOut.into()),
                Some(libc::EBUSY) | Some(libc::EAGAIN) => Err(io::ErrorKind::Interrupted.into()),
                _ => Err(e),
            },
        }
    }

    /// Drain completions produced by the blocking thread pool and notify
    /// their waiters.
    fn poll_blocking(&mut self) {
        // Cheaper than pop.
        if !self.pool_completed.is_empty() {
            while let Some(entry) = self.pool_completed.pop() {
                entry.notify();
            }
        }
    }

    /// Drain both the blocking-pool queue and the ring's completion queue,
    /// dispatching each completion. Returns whether any ring CQE was seen.
    fn poll_entries(&mut self) -> bool {
        self.poll_blocking();

        let mut cqueue = self.inner.completion();
        cqueue.sync();
        let has_entry = !cqueue.is_empty();
        for entry in cqueue {
            match entry.user_data() {
                // Result of an AsyncCancel push — nothing to dispatch.
                Self::CANCEL => {}
                Self::NOTIFY => {
                    let flags = entry.flags();
                    // Without IORING_CQE_F_MORE the multishot poll has been
                    // terminated by the kernel; re-arm it on the next `poll`.
                    if !more(flags) {
                        self.need_push_notifier = true;
                    }
                    self.notifier.clear().expect("cannot clear notifier");
                }
                _ => create_entry(entry).notify(),
            }
        }
        has_entry
    }

    /// Default per-operation extra state for this backend.
    pub fn default_extra(&self) -> Extra {
        Extra::new()
    }

    /// io-uring needs no per-fd registration; always succeeds.
    pub fn attach(&mut self, _fd: RawFd) -> io::Result<()> {
        Ok(())
    }

    /// Request cancellation of the operation identified by `key` by pushing
    /// an `AsyncCancel` SQE. Best-effort: if the submission queue is full the
    /// failure is only logged.
    pub fn cancel<T>(&mut self, key: Key<T>) {
        instrument!(compio_log::Level::TRACE, "cancel", ?key);
        trace!("cancel RawOp");
        unsafe {
            #[allow(clippy::useless_conversion)]
            if self
                .inner
                .submission()
                .push(
                    &AsyncCancel::new(key.into_raw() as _)
                        .build()
                        .user_data(Self::CANCEL)
                        .into(),
                )
                .is_err()
            {
                warn!("could not push AsyncCancel entry");
            }
        }
    }

    /// Tag `entry` with the key's raw pointer as `user_data` and push it,
    /// leaking the key only once the push succeeded.
    fn push_raw_with_key(&mut self, entry: SEntry, key: ErasedKey) -> io::Result<()> {
        let entry = entry.user_data(key.as_raw() as _);
        self.push_raw(entry)?; // if push failed, do not leak the key. Drop it upon return.
        key.into_raw();
        Ok(())
    }

    /// Push an SQE, retrying when the submission queue is full: drain
    /// completions, submit with a zero timeout, and try again.
    fn push_raw(&mut self, entry: SEntry) -> io::Result<()> {
        loop {
            let mut squeue = self.inner.submission();
            match unsafe { squeue.push(&entry) } {
                Ok(()) => {
                    squeue.sync();
                    break Ok(());
                }
                Err(_) => {
                    // Queue full: release the borrow, make room, retry.
                    drop(squeue);
                    self.poll_entries();
                    match self.submit_auto(Some(Duration::ZERO)) {
                        Ok(()) => {}
                        // These just mean "nothing completed yet" — keep looping.
                        Err(e)
                            if matches!(
                                e.kind(),
                                io::ErrorKind::TimedOut | io::ErrorKind::Interrupted
                            ) => {}
                        Err(e) => return Err(e),
                    }
                }
            }
        }
    }

    /// Submit the operation behind `key`: ring submission for async ops,
    /// thread-pool dispatch for blocking ones. Always returns `Pending` on
    /// success; completion arrives via `poll`.
    pub fn push(&mut self, key: ErasedKey) -> Poll<io::Result<usize>> {
        instrument!(compio_log::Level::TRACE, "push", ?key);
        let personality = key.borrow().extra().as_iour().get_personality();
        let entry = key
            .borrow()
            .pinned_op()
            .create_entry()
            .personality(personality);
        trace!(?personality, "push RawOp");
        match entry {
            OpEntry::Submission(entry) => {
                #[allow(clippy::useless_conversion)]
                self.push_raw_with_key(entry.into(), key)?;
                Poll::Pending
            }
            #[cfg(feature = "io-uring-sqe128")]
            OpEntry::Submission128(entry) => {
                self.push_raw_with_key(entry, key)?;
                Poll::Pending
            }
            // Retry dispatch until the pool accepts the job, draining its
            // completed queue between attempts.
            OpEntry::Blocking => loop {
                if self.push_blocking(key.clone()) {
                    break Poll::Pending;
                } else {
                    self.poll_blocking();
                }
            },
        }
    }

    /// Try to run the operation on the thread pool; returns whether dispatch
    /// succeeded. On completion the result is queued and the driver is woken.
    fn push_blocking(&mut self, key: ErasedKey) -> bool {
        let waker = self.waker();
        let completed = self.pool_completed.clone();
        // SAFETY: we're submitting into the driver, so it's safe to freeze here.
        let mut key = unsafe { key.freeze() };
        self.pool
            .dispatch(move || {
                let res = key.pinned_op().call_blocking();
                completed.push(Entry::new(key.into_inner(), res));
                waker.wake();
            })
            .is_ok()
    }

    /// Drive the ring once: (re)arm the notifier poll if needed, then either
    /// dispatch already-available completions or submit and wait up to
    /// `timeout` for new ones.
    pub fn poll(&mut self, timeout: Option<Duration>) -> io::Result<()> {
        instrument!(compio_log::Level::TRACE, "poll", ?timeout);
        // Anyway we need to submit once, no matter there are entries in squeue.
        trace!("start polling");

        if self.need_push_notifier {
            #[allow(clippy::useless_conversion)]
            self.push_raw(
                PollAdd::new(Fd(self.notifier.as_raw_fd()), libc::POLLIN as _)
                    .multi(true)
                    .build()
                    .user_data(Self::NOTIFY)
                    .into(),
            )?;
            self.need_push_notifier = false;
        }

        if !self.poll_entries() {
            self.submit_auto(timeout)?;
            self.poll_entries();
        }

        Ok(())
    }

    /// A waker that signals the notifier eventfd, unblocking `poll`.
    pub fn waker(&self) -> Waker {
        self.notifier.waker()
    }

    /// Create a registered buffer ring of `buffer_len` buffers, each
    /// `buffer_size` bytes, and wrap it in a [`BufferPool`].
    pub fn create_buffer_pool(
        &mut self,
        buffer_len: u16,
        buffer_size: usize,
    ) -> io::Result<BufferPool> {
        let buffer_group = self.buffer_group_ids.insert(());
        // Buffer group ids are u16 on the kernel side; roll back the slab
        // slot if we've run out of representable ids.
        if buffer_group > u16::MAX as usize {
            self.buffer_group_ids.remove(buffer_group);

            return Err(io::Error::new(
                io::ErrorKind::OutOfMemory,
                "too many buffer pool allocated",
            ));
        }

        let buf_ring = io_uring_buf_ring::IoUringBufRing::new_with_flags(
            &self.inner,
            buffer_len,
            buffer_group as _,
            buffer_size,
            0,
        )?;

        #[cfg(fusion)]
        {
            Ok(BufferPool::new_io_uring(crate::IoUringBufferPool::new(
                buf_ring,
            )))
        }
        #[cfg(not(fusion))]
        {
            Ok(BufferPool::new(buf_ring))
        }
    }

    /// # Safety
    ///
    /// caller must make sure release the buffer pool with correct driver
    pub unsafe fn release_buffer_pool(&mut self, buffer_pool: BufferPool) -> io::Result<()> {
        #[cfg(fusion)]
        let buffer_pool = buffer_pool.into_io_uring();

        let buffer_group = buffer_pool.buffer_group();
        unsafe { buffer_pool.into_inner().release(&self.inner)? };
        // Free the id so a future pool can reuse it.
        self.buffer_group_ids.remove(buffer_group as _);

        Ok(())
    }
}
451
// Expose the ring's fd, e.g. for external readiness integration.
impl AsRawFd for Driver {
    fn as_raw_fd(&self) -> RawFd {
        self.inner.as_raw_fd()
    }
}
457
458fn create_entry(cq_entry: CEntry) -> Entry {
459    let result = cq_entry.result();
460    let result = if result < 0 {
461        let result = if result == -libc::ECANCELED {
462            libc::ETIMEDOUT
463        } else {
464            -result
465        };
466        Err(io::Error::from_raw_os_error(result))
467    } else {
468        Ok(result as _)
469    };
470    let key = unsafe { ErasedKey::from_raw(cq_entry.user_data() as _) };
471    let mut entry = Entry::new(key, result);
472    entry.set_flags(cq_entry.flags());
473
474    entry
475}
476
477fn timespec(duration: std::time::Duration) -> Timespec {
478    Timespec::new()
479        .sec(duration.as_secs())
480        .nsec(duration.subsec_nanos())
481}
482
// Owner of the eventfd used to wake the driver; the shared `Notify` inside is
// what wakers clone.
#[derive(Debug)]
struct Notifier {
    notify: Arc<Notify>,
}
487
impl Notifier {
    /// Create a new notifier backed by a non-blocking, close-on-exec eventfd.
    fn new() -> io::Result<Self> {
        // EFD_NONBLOCK lets `clear` bail out with WouldBlock when the counter
        // is already zero.
        let fd = syscall!(libc::eventfd(0, libc::EFD_CLOEXEC | libc::EFD_NONBLOCK))?;
        let fd = unsafe { OwnedFd::from_raw_fd(fd) };
        Ok(Self {
            notify: Arc::new(Notify::new(fd)),
        })
    }

    /// Drain the eventfd counter so subsequent notifications raise POLLIN again.
    pub fn clear(&self) -> io::Result<()> {
        loop {
            let mut buffer = [0u64];
            let res = syscall!(libc::read(
                self.as_raw_fd(),
                buffer.as_mut_ptr().cast(),
                std::mem::size_of::<u64>()
            ));
            match res {
                Ok(len) => {
                    // eventfd reads always transfer exactly 8 bytes.
                    debug_assert_eq!(len, std::mem::size_of::<u64>() as _);
                    break Ok(());
                }
                // Clear the next time:)
                Err(e) if e.kind() == io::ErrorKind::WouldBlock => break Ok(()),
                // Just like read_exact
                Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
                Err(e) => break Err(e),
            }
        }
    }

    /// Build a [`Waker`] that writes to the eventfd when woken.
    pub fn waker(&self) -> Waker {
        Waker::from(self.notify.clone())
    }
}
524
// The fd that `Driver::poll` registers with `PollAdd` to watch for wakeups.
impl AsRawFd for Notifier {
    fn as_raw_fd(&self) -> RawFd {
        self.notify.fd.as_raw_fd()
    }
}
530
/// A notify handle to the inner driver.
#[derive(Debug)]
pub(crate) struct Notify {
    // Owned eventfd created in `Notifier::new`.
    fd: OwnedFd,
}
536
537impl Notify {
538    pub(crate) fn new(fd: OwnedFd) -> Self {
539        Self { fd }
540    }
541
542    /// Notify the inner driver.
543    pub fn notify(&self) -> io::Result<()> {
544        let data = 1u64;
545        syscall!(libc::write(
546            self.fd.as_raw_fd(),
547            &data as *const _ as *const _,
548            std::mem::size_of::<u64>(),
549        ))?;
550        Ok(())
551    }
552}
553
// Lets `Arc<Notify>` be turned directly into a `Waker`.
impl Wake for Notify {
    fn wake(self: Arc<Self>) {
        self.wake_by_ref();
    }

    fn wake_by_ref(self: &Arc<Self>) {
        // Best-effort wakeup: a failed eventfd write is deliberately ignored.
        self.notify().ok();
    }
}