compio_driver/iour/mod.rs

#[cfg_attr(all(doc, docsrs), doc(cfg(all())))]
#[allow(unused_imports)]
pub use std::os::fd::{AsFd, AsRawFd, BorrowedFd, OwnedFd, RawFd};
use std::{io, os::fd::FromRawFd, pin::Pin, sync::Arc, task::Poll, time::Duration};

use compio_log::{instrument, trace, warn};
use crossbeam_queue::SegQueue;
cfg_if::cfg_if! {
    if #[cfg(feature = "io-uring-cqe32")] {
        use io_uring::cqueue::Entry32 as CEntry;
    } else {
        use io_uring::cqueue::Entry as CEntry;
    }
}
cfg_if::cfg_if! {
    if #[cfg(feature = "io-uring-sqe128")] {
        use io_uring::squeue::Entry128 as SEntry;
    } else {
        use io_uring::squeue::Entry as SEntry;
    }
}
use io_uring::{
    IoUring,
    cqueue::more,
    opcode::{AsyncCancel, PollAdd},
    types::{Fd, SubmitArgs, Timespec},
};
use slab::Slab;

use crate::{AsyncifyPool, BufferPool, DriverType, Entry, Key, ProactorBuilder, syscall};

pub(crate) mod op;

/// The created entry of [`OpCode`].
pub enum OpEntry {
    /// This operation creates an io-uring submission entry.
    Submission(io_uring::squeue::Entry),
    #[cfg(feature = "io-uring-sqe128")]
    /// This operation creates a 128-byte io-uring submission entry.
    Submission128(io_uring::squeue::Entry128),
    /// This operation is a blocking one.
    Blocking,
}

impl From<io_uring::squeue::Entry> for OpEntry {
    fn from(value: io_uring::squeue::Entry) -> Self {
        Self::Submission(value)
    }
}

#[cfg(feature = "io-uring-sqe128")]
impl From<io_uring::squeue::Entry128> for OpEntry {
    fn from(value: io_uring::squeue::Entry128) -> Self {
        Self::Submission128(value)
    }
}

/// Abstraction of io-uring operations.
pub trait OpCode {
    /// Create the submission entry.
    fn create_entry(self: Pin<&mut Self>) -> OpEntry;

    /// Call the operation in a blocking way. This method will only be called
    /// if [`create_entry`] returns [`OpEntry::Blocking`].
    fn call_blocking(self: Pin<&mut Self>) -> io::Result<usize> {
        unreachable!("this operation is asynchronous")
    }

    /// Set the result when the operation successfully completes.
    /// The operation stores the result and is responsible for releasing it if
    /// the operation is cancelled.
    ///
    /// # Safety
    ///
    /// Users should not call it.
    unsafe fn set_result(self: Pin<&mut Self>, _: usize) {}
}
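
// A minimal sketch of how an `OpCode` implementation might look, assuming a
// hypothetical `ReadAt` operation (the type and its fields are illustrative
// only; see `op.rs` for this crate's real implementations):
//
//     struct ReadAt {
//         fd: RawFd,
//         offset: u64,
//         buffer: Vec<u8>,
//     }
//
//     impl OpCode for ReadAt {
//         fn create_entry(self: Pin<&mut Self>) -> OpEntry {
//             // `ReadAt` is `Unpin`, so the pin can be unwrapped safely.
//             let this = self.get_mut();
//             io_uring::opcode::Read::new(
//                 Fd(this.fd),
//                 this.buffer.as_mut_ptr(),
//                 this.buffer.len() as _,
//             )
//             .offset(this.offset)
//             .build()
//             // `From<io_uring::squeue::Entry>` wraps it as `OpEntry::Submission`.
//             .into()
//         }
//     }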

/// Low-level driver of io-uring.
pub(crate) struct Driver {
    inner: IoUring<SEntry, CEntry>,
    notifier: Notifier,
    pool: AsyncifyPool,
    pool_completed: Arc<SegQueue<Entry>>,
    buffer_group_ids: Slab<()>,
    need_push_notifier: bool,
}

impl Driver {
    const CANCEL: u64 = u64::MAX;
    const NOTIFY: u64 = u64::MAX - 1;

    pub fn new(builder: &ProactorBuilder) -> io::Result<Self> {
        instrument!(compio_log::Level::TRACE, "new", ?builder);
        trace!("new iour driver");
        let notifier = Notifier::new()?;
        let mut io_uring_builder = IoUring::builder();
        if let Some(sqpoll_idle) = builder.sqpoll_idle {
            io_uring_builder.setup_sqpoll(sqpoll_idle.as_millis() as _);
        }
        if builder.coop_taskrun {
            io_uring_builder.setup_coop_taskrun();
        }
        if builder.taskrun_flag {
            io_uring_builder.setup_taskrun_flag();
        }

        let inner = io_uring_builder.build(builder.capacity)?;

        if let Some(fd) = builder.eventfd {
            inner.submitter().register_eventfd(fd)?;
        }

        Ok(Self {
            inner,
            notifier,
            pool: builder.create_or_get_thread_pool(),
            pool_completed: Arc::new(SegQueue::new()),
            buffer_group_ids: Slab::new(),
            need_push_notifier: true,
        })
    }

    pub fn driver_type(&self) -> DriverType {
        DriverType::IoUring
    }

    // "Auto" means that it chooses whether to wait automatically.
    fn submit_auto(&mut self, timeout: Option<Duration>) -> io::Result<()> {
        instrument!(compio_log::Level::TRACE, "submit_auto", ?timeout);

        // When taskrun is set, there are completed CQEs waiting to be handled,
        // so there is no need to block on the submit.
        let want_sqe = if self.inner.submission().taskrun() {
            0
        } else {
            1
        };

        let res = {
            // Submit the submission queue; if a timeout is given, wait for
            // completions until it expires.
            if let Some(duration) = timeout {
                let timespec = timespec(duration);
                let args = SubmitArgs::new().timespec(&timespec);
                self.inner.submitter().submit_with_args(want_sqe, &args)
            } else {
                self.inner.submit_and_wait(want_sqe)
            }
        };
        trace!("submit result: {res:?}");
        match res {
            Ok(_) => {
                if self.inner.completion().is_empty() {
                    Err(io::ErrorKind::TimedOut.into())
                } else {
                    Ok(())
                }
            }
            Err(e) => match e.raw_os_error() {
                Some(libc::ETIME) => Err(io::ErrorKind::TimedOut.into()),
                Some(libc::EBUSY) | Some(libc::EAGAIN) => Err(io::ErrorKind::Interrupted.into()),
                _ => Err(e),
            },
        }
    }

    fn poll_blocking(&mut self) {
        // Cheaper than pop.
        if !self.pool_completed.is_empty() {
            while let Some(entry) = self.pool_completed.pop() {
                unsafe {
                    entry.notify();
                }
            }
        }
    }

    fn poll_entries(&mut self) -> bool {
        self.poll_blocking();

        let mut cqueue = self.inner.completion();
        cqueue.sync();
        let has_entry = !cqueue.is_empty();
        for entry in cqueue {
            match entry.user_data() {
                Self::CANCEL => {}
                Self::NOTIFY => {
                    let flags = entry.flags();
                    // The multishot PollAdd is no longer armed; push it again
                    // on the next poll.
                    if !more(flags) {
                        self.need_push_notifier = true;
                    }
                    self.notifier.clear().expect("cannot clear notifier");
                }
                _ => unsafe {
                    create_entry(entry).notify();
                },
            }
        }
        has_entry
    }

    pub fn create_op<T: crate::sys::OpCode + 'static>(&self, op: T) -> Key<T> {
        Key::new(self.as_raw_fd(), op)
    }

    pub fn attach(&mut self, _fd: RawFd) -> io::Result<()> {
        Ok(())
    }

    pub fn cancel(&mut self, op: &mut Key<dyn crate::sys::OpCode>) {
        instrument!(compio_log::Level::TRACE, "cancel", ?op);
        trace!("cancel RawOp");
        unsafe {
            #[allow(clippy::useless_conversion)]
            if self
                .inner
                .submission()
                .push(
                    &AsyncCancel::new(op.user_data() as _)
                        .build()
                        .user_data(Self::CANCEL)
                        .into(),
                )
                .is_err()
            {
                warn!("could not push AsyncCancel entry");
            }
        }
    }

    fn push_raw(&mut self, entry: SEntry) -> io::Result<()> {
        loop {
            let mut squeue = self.inner.submission();
            match unsafe { squeue.push(&entry) } {
                Ok(()) => {
                    squeue.sync();
                    break Ok(());
                }
                Err(_) => {
                    // The submission queue is full; drain completions and
                    // submit to make room, then retry.
                    drop(squeue);
                    self.poll_entries();
                    match self.submit_auto(Some(Duration::ZERO)) {
                        Ok(()) => {}
                        Err(e)
                            if matches!(
                                e.kind(),
                                io::ErrorKind::TimedOut | io::ErrorKind::Interrupted
                            ) => {}
                        Err(e) => return Err(e),
                    }
                }
            }
        }
    }

    pub fn push(&mut self, op: &mut Key<dyn crate::sys::OpCode>) -> Poll<io::Result<usize>> {
        instrument!(compio_log::Level::TRACE, "push", ?op);
        let user_data = op.user_data();
        let op_pin = op.as_op_pin();
        trace!("push RawOp");
        match op_pin.create_entry() {
            OpEntry::Submission(entry) => {
                #[allow(clippy::useless_conversion)]
                self.push_raw(entry.user_data(user_data as _).into())?;
                Poll::Pending
            }
            #[cfg(feature = "io-uring-sqe128")]
            OpEntry::Submission128(entry) => {
                self.push_raw(entry.user_data(user_data as _))?;
                Poll::Pending
            }
            OpEntry::Blocking => loop {
                if self.push_blocking(user_data) {
                    break Poll::Pending;
                } else {
                    self.poll_blocking();
                }
            },
        }
    }

    fn push_blocking(&mut self, user_data: usize) -> bool {
        let handle = self.handle();
        let completed = self.pool_completed.clone();
        self.pool
            .dispatch(move || {
                let mut op = unsafe { Key::<dyn crate::sys::OpCode>::new_unchecked(user_data) };
                let op_pin = op.as_op_pin();
                let res = op_pin.call_blocking();
                completed.push(Entry::new(user_data, res));
                handle.notify().ok();
            })
            .is_ok()
    }

    pub unsafe fn poll(&mut self, timeout: Option<Duration>) -> io::Result<()> {
        instrument!(compio_log::Level::TRACE, "poll", ?timeout);
        // We need to submit at least once, regardless of whether there are
        // entries in the squeue.
        trace!("start polling");

        if self.need_push_notifier {
            #[allow(clippy::useless_conversion)]
            self.push_raw(
                PollAdd::new(Fd(self.notifier.as_raw_fd()), libc::POLLIN as _)
                    .multi(true)
                    .build()
                    .user_data(Self::NOTIFY)
                    .into(),
            )?;
            self.need_push_notifier = false;
        }

        if !self.poll_entries() {
            self.submit_auto(timeout)?;
            self.poll_entries();
        }

        Ok(())
    }

    pub fn handle(&self) -> NotifyHandle {
        self.notifier.handle()
    }

    pub fn create_buffer_pool(
        &mut self,
        buffer_len: u16,
        buffer_size: usize,
    ) -> io::Result<BufferPool> {
        let buffer_group = self.buffer_group_ids.insert(());
        if buffer_group > u16::MAX as usize {
            self.buffer_group_ids.remove(buffer_group);

            return Err(io::Error::new(
                io::ErrorKind::OutOfMemory,
                "too many buffer pools allocated",
            ));
        }

        let buf_ring = io_uring_buf_ring::IoUringBufRing::new(
            &self.inner,
            buffer_len,
            buffer_group as _,
            buffer_size,
        )?;

        #[cfg(fusion)]
        {
            Ok(BufferPool::new_io_uring(crate::IoUringBufferPool::new(
                buf_ring,
            )))
        }
        #[cfg(not(fusion))]
        {
            Ok(BufferPool::new(buf_ring))
        }
    }

    /// # Safety
    ///
    /// The caller must make sure to release the buffer pool with the correct
    /// driver.
    pub unsafe fn release_buffer_pool(&mut self, buffer_pool: BufferPool) -> io::Result<()> {
        #[cfg(fusion)]
        let buffer_pool = buffer_pool.into_io_uring();

        let buffer_group = buffer_pool.buffer_group();
        buffer_pool.into_inner().release(&self.inner)?;
        self.buffer_group_ids.remove(buffer_group as _);

        Ok(())
    }
}

impl AsRawFd for Driver {
    fn as_raw_fd(&self) -> RawFd {
        self.inner.as_raw_fd()
    }
}

/// Convert a completion queue entry into a driver [`Entry`].
fn create_entry(cq_entry: CEntry) -> Entry {
    let result = cq_entry.result();
    let result = if result < 0 {
        // Cancelled operations are surfaced as timed out.
        let result = if result == -libc::ECANCELED {
            libc::ETIMEDOUT
        } else {
            -result
        };
        Err(io::Error::from_raw_os_error(result))
    } else {
        Ok(result as _)
    };
    let mut entry = Entry::new(cq_entry.user_data() as _, result);
    entry.set_flags(cq_entry.flags());

    entry
}

fn timespec(duration: std::time::Duration) -> Timespec {
    Timespec::new()
        .sec(duration.as_secs())
        .nsec(duration.subsec_nanos())
}

#[derive(Debug)]
struct Notifier {
    fd: Arc<OwnedFd>,
}

impl Notifier {
    /// Create a new notifier.
    fn new() -> io::Result<Self> {
        let fd = syscall!(libc::eventfd(0, libc::EFD_CLOEXEC | libc::EFD_NONBLOCK))?;
        let fd = unsafe { OwnedFd::from_raw_fd(fd) };
        Ok(Self { fd: Arc::new(fd) })
    }

    pub fn clear(&self) -> io::Result<()> {
        loop {
            let mut buffer = [0u64];
            let res = syscall!(libc::read(
                self.fd.as_raw_fd(),
                buffer.as_mut_ptr().cast(),
                std::mem::size_of::<u64>()
            ));
            match res {
                Ok(len) => {
                    debug_assert_eq!(len, std::mem::size_of::<u64>() as _);
                    break Ok(());
                }
                // Nothing to read; clear it the next time :)
                Err(e) if e.kind() == io::ErrorKind::WouldBlock => break Ok(()),
                // Retry on interruption, just like `read_exact`.
                Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
                Err(e) => break Err(e),
            }
        }
    }

    pub fn handle(&self) -> NotifyHandle {
        NotifyHandle::new(self.fd.clone())
    }
}

impl AsRawFd for Notifier {
    fn as_raw_fd(&self) -> RawFd {
        self.fd.as_raw_fd()
    }
}

/// A notify handle to the inner driver.
pub struct NotifyHandle {
    fd: Arc<OwnedFd>,
}

impl NotifyHandle {
    pub(crate) fn new(fd: Arc<OwnedFd>) -> Self {
        Self { fd }
    }

    /// Notify the inner driver.
    pub fn notify(&self) -> io::Result<()> {
        let data = 1u64;
        syscall!(libc::write(
            self.fd.as_raw_fd(),
            &data as *const _ as *const _,
            std::mem::size_of::<u64>(),
        ))?;
        Ok(())
    }
}
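
// A minimal usage sketch (not part of this module): a `NotifyHandle` can be
// sent to another thread to wake a driver blocked in `poll`. The `driver`
// value here is hypothetical; users normally reach this through the public
// proactor API.
//
//     let handle = driver.handle();
//     std::thread::spawn(move || {
//         // Writing to the eventfd completes the multishot `PollAdd` entry
//         // registered under `Self::NOTIFY`, waking the polling thread.
//         handle.notify().unwrap();
//     });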