//! io-uring backend of the compio driver (`compio_driver/sys/iour/mod.rs`).
1#[cfg_attr(all(doc, docsrs), doc(cfg(all())))]
2#[allow(unused_imports)]
3pub use std::os::fd::{AsFd, AsRawFd, BorrowedFd, OwnedFd, RawFd};
4use std::{
5    io,
6    os::fd::FromRawFd,
7    pin::Pin,
8    sync::Arc,
9    task::{Poll, Wake, Waker},
10    time::Duration,
11};
12
13use compio_log::{instrument, trace, warn};
14use crossbeam_queue::SegQueue;
15cfg_if::cfg_if! {
16    if #[cfg(feature = "io-uring-cqe32")] {
17        use io_uring::cqueue::Entry32 as CEntry;
18    } else {
19        use io_uring::cqueue::Entry as CEntry;
20    }
21}
22cfg_if::cfg_if! {
23    if #[cfg(feature = "io-uring-sqe128")] {
24        use io_uring::squeue::Entry128 as SEntry;
25    } else {
26        use io_uring::squeue::Entry as SEntry;
27    }
28}
29use io_uring::{
30    IoUring,
31    cqueue::more,
32    opcode::{AsyncCancel, PollAdd},
33    types::{Fd, SubmitArgs, Timespec},
34};
35use slab::Slab;
36
37use crate::{
38    AsyncifyPool, BufferPool, DriverType, Entry, ProactorBuilder,
39    key::{ErasedKey, Key, RefExt},
40    syscall,
41};
42
43mod extra;
44pub use extra::Extra;
45pub(crate) mod op;
46
47pub(crate) fn is_op_supported(code: u8) -> bool {
48    #[cfg(feature = "once_cell_try")]
49    use std::sync::OnceLock;
50
51    #[cfg(not(feature = "once_cell_try"))]
52    use once_cell::sync::OnceCell as OnceLock;
53
54    static PROBE: OnceLock<io_uring::Probe> = OnceLock::new();
55
56    PROBE
57        .get_or_try_init(|| {
58            let mut probe = io_uring::Probe::new();
59
60            io_uring::IoUring::new(2)?
61                .submitter()
62                .register_probe(&mut probe)?;
63
64            std::io::Result::Ok(probe)
65        })
66        .map(|probe| probe.is_supported(code))
67        .unwrap_or_default()
68}
69
/// The created entry of [`OpCode`].
pub enum OpEntry {
    /// This operation creates an io-uring submission entry.
    Submission(io_uring::squeue::Entry),
    #[cfg(feature = "io-uring-sqe128")]
    /// This operation creates a 128-bit io-uring submission entry.
    Submission128(io_uring::squeue::Entry128),
    /// This operation is a blocking one; the driver dispatches it to its
    /// thread pool instead of submitting it to the ring.
    Blocking,
}
80
81impl OpEntry {
82    fn personality(self, personality: Option<u16>) -> Self {
83        let Some(personality) = personality else {
84            return self;
85        };
86
87        match self {
88            Self::Submission(entry) => Self::Submission(entry.personality(personality)),
89            #[cfg(feature = "io-uring-sqe128")]
90            Self::Submission128(entry) => Self::Submission128(entry.personality(personality)),
91            Self::Blocking => Self::Blocking,
92        }
93    }
94}
95
96impl From<io_uring::squeue::Entry> for OpEntry {
97    fn from(value: io_uring::squeue::Entry) -> Self {
98        Self::Submission(value)
99    }
100}
101
#[cfg(feature = "io-uring-sqe128")]
impl From<io_uring::squeue::Entry128> for OpEntry {
    /// Wrap a 128-bit submission entry.
    fn from(value: io_uring::squeue::Entry128) -> Self {
        OpEntry::Submission128(value)
    }
}
108
/// Abstraction of io-uring operations.
///
/// # Safety
///
/// The returned Entry from `create_entry` must be valid until the operation is
/// completed.
pub unsafe trait OpCode {
    /// Create submission entry.
    fn create_entry(self: Pin<&mut Self>) -> OpEntry;

    /// Call the operation in a blocking way. This method will only be called if
    /// [`create_entry`] returns [`OpEntry::Blocking`].
    ///
    /// [`create_entry`]: OpCode::create_entry
    fn call_blocking(self: Pin<&mut Self>) -> io::Result<usize> {
        // Operations that submit to the ring never reach the blocking path,
        // so the default implementation is unreachable for them.
        unreachable!("this operation is asynchronous")
    }

    /// Set the result when it successfully completes.
    /// The operation stores the result and is responsible to release it if the
    /// operation is cancelled.
    ///
    /// # Safety
    ///
    /// Users should not call it.
    unsafe fn set_result(self: Pin<&mut Self>, _: usize) {}
}
136
137pub use OpCode as IourOpCode;
138
/// Low-level driver of io-uring.
pub(crate) struct Driver {
    // The ring itself; entry types are selected by the sqe128/cqe32 features.
    inner: IoUring<SEntry, CEntry>,
    // Eventfd-backed wakeup channel used by wakers to interrupt `poll`.
    notifier: Notifier,
    // Thread pool running operations that fall back to blocking calls.
    pool: AsyncifyPool,
    // Results produced by the blocking pool, drained in `poll_blocking`.
    pool_completed: Arc<SegQueue<Entry>>,
    // Allocator of provided-buffer group ids for buffer pools.
    buffer_group_ids: Slab<()>,
    // Whether the multishot PollAdd watching the notifier must be (re)armed.
    need_push_notifier: bool,
}
148
impl Driver {
    /// Sentinel user data for [`AsyncCancel`] completions; they carry no
    /// operation key and are discarded in `poll_entries`.
    const CANCEL: u64 = u64::MAX;
    /// Sentinel user data for the multishot [`PollAdd`] entry that watches
    /// the notifier eventfd.
    const NOTIFY: u64 = u64::MAX - 1;

    /// Create an io-uring driver configured from `builder` (sqpoll idle,
    /// taskrun flags, queue capacity, optional registered eventfd).
    pub fn new(builder: &ProactorBuilder) -> io::Result<Self> {
        instrument!(compio_log::Level::TRACE, "new", ?builder);
        trace!("new iour driver");
        let notifier = Notifier::new()?;
        let mut io_uring_builder = IoUring::builder();
        if let Some(sqpoll_idle) = builder.sqpoll_idle {
            io_uring_builder.setup_sqpoll(sqpoll_idle.as_millis() as _);
        }
        if builder.coop_taskrun {
            io_uring_builder.setup_coop_taskrun();
        }
        if builder.taskrun_flag {
            io_uring_builder.setup_taskrun_flag();
        }

        let inner = io_uring_builder.build(builder.capacity)?;

        let submitter = inner.submitter();

        if let Some(fd) = builder.eventfd {
            submitter.register_eventfd(fd)?;
        }

        Ok(Self {
            inner,
            notifier,
            pool: builder.create_or_get_thread_pool(),
            pool_completed: Arc::new(SegQueue::new()),
            buffer_group_ids: Slab::new(),
            // The notifier PollAdd is armed lazily on the first `poll`.
            need_push_notifier: true,
        })
    }

    /// This driver is always the io-uring flavor.
    pub fn driver_type(&self) -> DriverType {
        DriverType::IoUring
    }

    /// Downcast helper; always `Some` for this backend.
    #[allow(dead_code)]
    pub fn as_iour(&self) -> Option<&Self> {
        Some(self)
    }

    /// Register a personality (credentials) with the ring and return its id.
    pub fn register_personality(&self) -> io::Result<u16> {
        self.inner.submitter().register_personality()
    }

    /// Unregister a personality previously returned by
    /// `register_personality`.
    pub fn unregister_personality(&self, personality: u16) -> io::Result<()> {
        self.inner.submitter().unregister_personality(personality)
    }

    // Auto means that it choose to wait or not automatically.
    fn submit_auto(&mut self, timeout: Option<Duration>) -> io::Result<()> {
        instrument!(compio_log::Level::TRACE, "submit_auto", ?timeout);

        // when taskrun is true, there are completed cqes wait to handle, no need to
        // block the submit
        let want_sqe = if self.inner.submission().taskrun() {
            0
        } else {
            1
        };

        let res = {
            // Last part of submission queue, wait till timeout.
            if let Some(duration) = timeout {
                let timespec = timespec(duration);
                let args = SubmitArgs::new().timespec(&timespec);
                self.inner.submitter().submit_with_args(want_sqe, &args)
            } else {
                self.inner.submit_and_wait(want_sqe)
            }
        };
        trace!("submit result: {res:?}");
        match res {
            Ok(_) => {
                // Submission succeeded but nothing completed: surface this
                // as a timeout so the caller can retry.
                if self.inner.completion().is_empty() {
                    Err(io::ErrorKind::TimedOut.into())
                } else {
                    Ok(())
                }
            }
            Err(e) => match e.raw_os_error() {
                // ETIME: the wait timed out before any completion arrived.
                Some(libc::ETIME) => Err(io::ErrorKind::TimedOut.into()),
                // EBUSY/EAGAIN: the ring is temporarily busy; reported as
                // Interrupted, which callers treat as retryable.
                Some(libc::EBUSY) | Some(libc::EAGAIN) => Err(io::ErrorKind::Interrupted.into()),
                _ => Err(e),
            },
        }
    }

    /// Drain completions produced by the blocking thread pool and notify them.
    fn poll_blocking(&mut self) {
        // Cheaper than pop.
        if !self.pool_completed.is_empty() {
            while let Some(entry) = self.pool_completed.pop() {
                entry.notify();
            }
        }
    }

    /// Drain the blocking-pool queue and the completion queue, notifying each
    /// completed operation. Returns whether any CQE was present.
    fn poll_entries(&mut self) -> bool {
        self.poll_blocking();

        let mut cqueue = self.inner.completion();
        cqueue.sync();
        let has_entry = !cqueue.is_empty();
        for entry in cqueue {
            match entry.user_data() {
                Self::CANCEL => {}
                Self::NOTIFY => {
                    let flags = entry.flags();
                    // If the multishot poll will produce no further CQEs, it
                    // must be re-armed on the next `poll`.
                    if !more(flags) {
                        self.need_push_notifier = true;
                    }
                    self.notifier.clear().expect("cannot clear notifier");
                }
                _ => create_entry(entry).notify(),
            }
        }
        has_entry
    }

    /// Default per-operation extra data for this driver.
    pub fn default_extra(&self) -> Extra {
        Extra::new()
    }

    /// io-uring needs no per-fd registration; attaching is a no-op.
    pub fn attach(&mut self, _fd: RawFd) -> io::Result<()> {
        Ok(())
    }

    /// Request cancellation of the operation behind `key` by submitting an
    /// [`AsyncCancel`] entry. Best-effort: a failed push is only logged.
    pub fn cancel<T>(&mut self, key: Key<T>) {
        instrument!(compio_log::Level::TRACE, "cancel", ?key);
        trace!("cancel RawOp");
        unsafe {
            #[allow(clippy::useless_conversion)]
            if self
                .inner
                .submission()
                .push(
                    &AsyncCancel::new(key.as_raw() as _)
                        .build()
                        .user_data(Self::CANCEL)
                        .into(),
                )
                .is_err()
            {
                warn!("could not push AsyncCancel entry");
            }
        }
    }

    /// Tag `entry` with the key's raw value and push it; ownership of the key
    /// is handed to the ring only after the push succeeds.
    fn push_raw_with_key(&mut self, entry: SEntry, key: ErasedKey) -> io::Result<()> {
        let entry = entry.user_data(key.as_raw() as _);
        self.push_raw(entry)?; // if push failed, do not leak the key. Drop it upon return.
        key.into_raw();
        Ok(())
    }

    /// Push an entry into the submission queue, retrying while the queue is
    /// full by draining completions and submitting with a zero timeout.
    fn push_raw(&mut self, entry: SEntry) -> io::Result<()> {
        loop {
            let mut squeue = self.inner.submission();
            match unsafe { squeue.push(&entry) } {
                Ok(()) => {
                    squeue.sync();
                    break Ok(());
                }
                Err(_) => {
                    // Queue full: release the borrow, flush what we can, and
                    // retry the push.
                    drop(squeue);
                    self.poll_entries();
                    match self.submit_auto(Some(Duration::ZERO)) {
                        Ok(()) => {}
                        // Timeout/interrupt just mean no progress yet; retry.
                        Err(e)
                            if matches!(
                                e.kind(),
                                io::ErrorKind::TimedOut | io::ErrorKind::Interrupted
                            ) => {}
                        Err(e) => return Err(e),
                    }
                }
            }
        }
    }

    /// Submit the operation behind `key`. Returns `Poll::Pending` when queued
    /// on the ring, or dispatches to the blocking pool as a fallback.
    pub fn push(&mut self, key: ErasedKey) -> Poll<io::Result<usize>> {
        instrument!(compio_log::Level::TRACE, "push", ?key);
        let personality = key.borrow().extra().as_iour().get_personality();
        let entry = key
            .borrow()
            .pinned_op()
            .create_entry()
            .personality(personality);
        trace!(?personality, "push RawOp");
        match entry {
            OpEntry::Submission(entry) => {
                // Fall back to the blocking pool when the kernel lacks this
                // opcode.
                if is_op_supported(entry.get_opcode() as _) {
                    #[allow(clippy::useless_conversion)]
                    self.push_raw_with_key(entry.into(), key)?;
                    Poll::Pending
                } else {
                    self.push_blocking_loop(key)
                }
            }
            #[cfg(feature = "io-uring-sqe128")]
            OpEntry::Submission128(entry) => {
                self.push_raw_with_key(entry, key)?;
                Poll::Pending
            }
            OpEntry::Blocking => self.push_blocking_loop(key),
        }
    }

    /// Keep trying to dispatch `key` to the thread pool, draining finished
    /// pool completions between attempts.
    fn push_blocking_loop(&mut self, key: ErasedKey) -> Poll<io::Result<usize>> {
        loop {
            if self.push_blocking(key.clone()) {
                break Poll::Pending;
            } else {
                self.poll_blocking();
            }
        }
    }

    /// Dispatch `key` to the thread pool; returns `false` when the pool could
    /// not accept the job.
    fn push_blocking(&mut self, key: ErasedKey) -> bool {
        let waker = self.waker();
        let completed = self.pool_completed.clone();
        // SAFETY: we're submitting into the driver, so it's safe to freeze here.
        let mut key = unsafe { key.freeze() };
        self.pool
            .dispatch(move || {
                let res = key.pinned_op().call_blocking();
                completed.push(Entry::new(key.into_inner(), res));
                // Wake the driver so it drains `pool_completed`.
                waker.wake();
            })
            .is_ok()
    }

    /// Drive the ring once: arm the notifier if needed, handle ready
    /// completions, and otherwise submit-and-wait up to `timeout`.
    pub fn poll(&mut self, timeout: Option<Duration>) -> io::Result<()> {
        instrument!(compio_log::Level::TRACE, "poll", ?timeout);
        // Anyway we need to submit once, no matter there are entries in squeue.
        trace!("start polling");

        if self.need_push_notifier {
            #[allow(clippy::useless_conversion)]
            self.push_raw(
                PollAdd::new(Fd(self.notifier.as_raw_fd()), libc::POLLIN as _)
                    .multi(true)
                    .build()
                    .user_data(Self::NOTIFY)
                    .into(),
            )?;
            self.need_push_notifier = false;
        }

        if !self.poll_entries() {
            self.submit_auto(timeout)?;
            self.poll_entries();
        }

        Ok(())
    }

    /// Create a waker that wakes this driver via the notifier eventfd.
    pub fn waker(&self) -> Waker {
        self.notifier.waker()
    }

    /// Create a buffer pool backed by an io-uring provided-buffer ring.
    pub fn create_buffer_pool(
        &mut self,
        buffer_len: u16,
        buffer_size: usize,
    ) -> io::Result<BufferPool> {
        let buffer_group = self.buffer_group_ids.insert(());
        // Buffer group ids are u16 on the ring; reject overflowing slot ids.
        if buffer_group > u16::MAX as usize {
            self.buffer_group_ids.remove(buffer_group);

            return Err(io::Error::new(
                io::ErrorKind::OutOfMemory,
                "too many buffer pool allocated",
            ));
        }

        let buf_ring = io_uring_buf_ring::IoUringBufRing::new_with_flags(
            &self.inner,
            buffer_len,
            buffer_group as _,
            buffer_size,
            0,
        )?;

        #[cfg(fusion)]
        {
            Ok(BufferPool::new_io_uring(crate::IoUringBufferPool::new(
                buf_ring,
            )))
        }
        #[cfg(not(fusion))]
        {
            Ok(BufferPool::new(buf_ring))
        }
    }

    /// # Safety
    ///
    /// caller must make sure release the buffer pool with correct driver
    pub unsafe fn release_buffer_pool(&mut self, buffer_pool: BufferPool) -> io::Result<()> {
        #[cfg(fusion)]
        let buffer_pool = buffer_pool.into_io_uring();

        let buffer_group = buffer_pool.buffer_group();
        // SAFETY: per this function's contract, the pool belongs to this ring.
        unsafe { buffer_pool.into_inner().release(&self.inner)? };
        self.buffer_group_ids.remove(buffer_group as _);

        Ok(())
    }
}
464
465impl AsRawFd for Driver {
466    fn as_raw_fd(&self) -> RawFd {
467        self.inner.as_raw_fd()
468    }
469}
470
471fn create_entry(cq_entry: CEntry) -> Entry {
472    let result = cq_entry.result();
473    let result = if result < 0 {
474        let result = if result == -libc::ECANCELED {
475            libc::ETIMEDOUT
476        } else {
477            -result
478        };
479        Err(io::Error::from_raw_os_error(result))
480    } else {
481        Ok(result as _)
482    };
483    let key = unsafe { ErasedKey::from_raw(cq_entry.user_data() as _) };
484    let mut entry = Entry::new(key, result);
485    entry.set_flags(cq_entry.flags());
486
487    entry
488}
489
490fn timespec(duration: std::time::Duration) -> Timespec {
491    Timespec::new()
492        .sec(duration.as_secs())
493        .nsec(duration.subsec_nanos())
494}
495
// Eventfd-backed notifier used to wake the driver out of `poll`.
#[derive(Debug)]
struct Notifier {
    // Shared so `Waker`s handed out by `waker` can outlive this handle.
    notify: Arc<Notify>,
}
500
501impl Notifier {
502    /// Create a new notifier.
503    fn new() -> io::Result<Self> {
504        let fd = syscall!(libc::eventfd(0, libc::EFD_CLOEXEC | libc::EFD_NONBLOCK))?;
505        let fd = unsafe { OwnedFd::from_raw_fd(fd) };
506        Ok(Self {
507            notify: Arc::new(Notify::new(fd)),
508        })
509    }
510
511    pub fn clear(&self) -> io::Result<()> {
512        loop {
513            let mut buffer = [0u64];
514            let res = syscall!(libc::read(
515                self.as_raw_fd(),
516                buffer.as_mut_ptr().cast(),
517                std::mem::size_of::<u64>()
518            ));
519            match res {
520                Ok(len) => {
521                    debug_assert_eq!(len, std::mem::size_of::<u64>() as _);
522                    break Ok(());
523                }
524                // Clear the next time:)
525                Err(e) if e.kind() == io::ErrorKind::WouldBlock => break Ok(()),
526                // Just like read_exact
527                Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
528                Err(e) => break Err(e),
529            }
530        }
531    }
532
533    pub fn waker(&self) -> Waker {
534        Waker::from(self.notify.clone())
535    }
536}
537
538impl AsRawFd for Notifier {
539    fn as_raw_fd(&self) -> RawFd {
540        self.notify.fd.as_raw_fd()
541    }
542}
543
/// A notify handle to the inner driver.
#[derive(Debug)]
pub(crate) struct Notify {
    // Owned eventfd written by `notify` and drained by `Notifier::clear`.
    fd: OwnedFd,
}
549
550impl Notify {
551    pub(crate) fn new(fd: OwnedFd) -> Self {
552        Self { fd }
553    }
554
555    /// Notify the inner driver.
556    pub fn notify(&self) -> io::Result<()> {
557        let data = 1u64;
558        syscall!(libc::write(
559            self.fd.as_raw_fd(),
560            &data as *const _ as *const _,
561            std::mem::size_of::<u64>(),
562        ))?;
563        Ok(())
564    }
565}
566
567impl Wake for Notify {
568    fn wake(self: Arc<Self>) {
569        self.wake_by_ref();
570    }
571
572    fn wake_by_ref(self: &Arc<Self>) {
573        self.notify().ok();
574    }
575}