compio_driver/iour/
mod.rs

1#[cfg_attr(all(doc, docsrs), doc(cfg(all())))]
2#[allow(unused_imports)]
3pub use std::os::fd::{AsFd, AsRawFd, BorrowedFd, OwnedFd, RawFd};
4use std::{io, os::fd::FromRawFd, pin::Pin, sync::Arc, task::Poll, time::Duration};
5
6use compio_log::{instrument, trace, warn};
7use crossbeam_queue::SegQueue;
8cfg_if::cfg_if! {
9    if #[cfg(feature = "io-uring-cqe32")] {
10        use io_uring::cqueue::Entry32 as CEntry;
11    } else {
12        use io_uring::cqueue::Entry as CEntry;
13    }
14}
15cfg_if::cfg_if! {
16    if #[cfg(feature = "io-uring-sqe128")] {
17        use io_uring::squeue::Entry128 as SEntry;
18    } else {
19        use io_uring::squeue::Entry as SEntry;
20    }
21}
22use io_uring::{
23    IoUring,
24    cqueue::more,
25    opcode::{AsyncCancel, PollAdd},
26    types::{Fd, SubmitArgs, Timespec},
27};
28pub(crate) use libc::{sockaddr_storage, socklen_t};
29#[cfg(io_uring)]
30use slab::Slab;
31
32use crate::{AsyncifyPool, BufferPool, Entry, Key, ProactorBuilder, syscall};
33
34pub(crate) mod op;
35
/// The created entry of [`OpCode`].
///
/// Returned by [`OpCode::create_entry`] to tell the driver how to run the
/// operation: push a submission entry into the ring, or dispatch it to the
/// blocking thread pool.
pub enum OpEntry {
    /// This operation creates an io-uring submission entry.
    Submission(io_uring::squeue::Entry),
    #[cfg(feature = "io-uring-sqe128")]
    /// This operation creates an 128-bit io-uring submission entry.
    Submission128(io_uring::squeue::Entry128),
    /// This operation is a blocking one; it will be run on the asyncify pool
    /// via [`OpCode::call_blocking`].
    Blocking,
}
46
47impl From<io_uring::squeue::Entry> for OpEntry {
48    fn from(value: io_uring::squeue::Entry) -> Self {
49        Self::Submission(value)
50    }
51}
52
#[cfg(feature = "io-uring-sqe128")]
impl From<io_uring::squeue::Entry128> for OpEntry {
    /// Wrap a 128-bit submission entry as [`OpEntry::Submission128`].
    fn from(entry: io_uring::squeue::Entry128) -> Self {
        OpEntry::Submission128(entry)
    }
}
59
/// Abstraction of io-uring operations.
pub trait OpCode {
    /// Create the submission entry for this operation, or declare that it must
    /// run on the blocking pool by returning [`OpEntry::Blocking`].
    fn create_entry(self: Pin<&mut Self>) -> OpEntry;

    /// Call the operation in a blocking way. This method will only be called if
    /// [`OpCode::create_entry`] returns [`OpEntry::Blocking`].
    ///
    /// The default implementation panics, since asynchronous operations never
    /// reach this path.
    fn call_blocking(self: Pin<&mut Self>) -> io::Result<usize> {
        unreachable!("this operation is asynchronous")
    }

    /// Set the result when it successfully completes.
    /// The operation stores the result and is responsible to release it if the
    /// operation is cancelled.
    ///
    /// The default implementation discards the value.
    ///
    /// # Safety
    ///
    /// Users should not call it.
    unsafe fn set_result(self: Pin<&mut Self>, _: usize) {}
}
80
/// Low-level driver of io-uring.
pub(crate) struct Driver {
    /// The ring itself; entry types are chosen by the
    /// `io-uring-sqe128`/`io-uring-cqe32` features via the `SEntry`/`CEntry`
    /// aliases.
    inner: IoUring<SEntry, CEntry>,
    /// Eventfd used to wake this driver from other threads.
    notifier: Notifier,
    /// Thread pool that executes [`OpEntry::Blocking`] operations.
    pool: AsyncifyPool,
    /// Completions produced by pool threads, drained during polling.
    pool_completed: Arc<SegQueue<Entry>>,
    /// Allocates buffer group ids for registered buffer rings.
    #[cfg(io_uring)]
    buffer_group_ids: Slab<()>,
}
90
91impl Driver {
92    const CANCEL: u64 = u64::MAX;
93    const NOTIFY: u64 = u64::MAX - 1;
94
95    pub fn new(builder: &ProactorBuilder) -> io::Result<Self> {
96        instrument!(compio_log::Level::TRACE, "new", ?builder);
97        trace!("new iour driver");
98        let notifier = Notifier::new()?;
99        let mut io_uring_builder = IoUring::builder();
100        if let Some(sqpoll_idle) = builder.sqpoll_idle {
101            io_uring_builder.setup_sqpoll(sqpoll_idle.as_millis() as _);
102        }
103        if builder.coop_taskrun {
104            io_uring_builder.setup_coop_taskrun();
105        }
106        if builder.taskrun_flag {
107            io_uring_builder.setup_taskrun_flag();
108        }
109
110        let mut inner = io_uring_builder.build(builder.capacity)?;
111        #[allow(clippy::useless_conversion)]
112        unsafe {
113            inner
114                .submission()
115                .push(
116                    &PollAdd::new(Fd(notifier.as_raw_fd()), libc::POLLIN as _)
117                        .multi(true)
118                        .build()
119                        .user_data(Self::NOTIFY)
120                        .into(),
121                )
122                .expect("the squeue sould not be full");
123        }
124        Ok(Self {
125            inner,
126            notifier,
127            pool: builder.create_or_get_thread_pool(),
128            pool_completed: Arc::new(SegQueue::new()),
129            #[cfg(io_uring)]
130            buffer_group_ids: Slab::new(),
131        })
132    }
133
134    // Auto means that it choose to wait or not automatically.
135    fn submit_auto(&mut self, timeout: Option<Duration>) -> io::Result<()> {
136        instrument!(compio_log::Level::TRACE, "submit_auto", ?timeout);
137
138        // when taskrun is true, there are completed cqes wait to handle, no need to
139        // block the submit
140        let want_sqe = if self.inner.submission().taskrun() {
141            0
142        } else {
143            1
144        };
145
146        let res = {
147            // Last part of submission queue, wait till timeout.
148            if let Some(duration) = timeout {
149                let timespec = timespec(duration);
150                let args = SubmitArgs::new().timespec(&timespec);
151                self.inner.submitter().submit_with_args(want_sqe, &args)
152            } else {
153                self.inner.submit_and_wait(want_sqe)
154            }
155        };
156        trace!("submit result: {res:?}");
157        match res {
158            Ok(_) => {
159                if self.inner.completion().is_empty() {
160                    Err(io::ErrorKind::TimedOut.into())
161                } else {
162                    Ok(())
163                }
164            }
165            Err(e) => match e.raw_os_error() {
166                Some(libc::ETIME) => Err(io::ErrorKind::TimedOut.into()),
167                Some(libc::EBUSY) | Some(libc::EAGAIN) => Err(io::ErrorKind::Interrupted.into()),
168                _ => Err(e),
169            },
170        }
171    }
172
173    fn poll_blocking(&mut self) {
174        // Cheaper than pop.
175        if !self.pool_completed.is_empty() {
176            while let Some(entry) = self.pool_completed.pop() {
177                unsafe {
178                    entry.notify();
179                }
180            }
181        }
182    }
183
184    fn poll_entries(&mut self) -> bool {
185        self.poll_blocking();
186
187        let mut cqueue = self.inner.completion();
188        cqueue.sync();
189        let has_entry = !cqueue.is_empty();
190        for entry in cqueue {
191            match entry.user_data() {
192                Self::CANCEL => {}
193                Self::NOTIFY => {
194                    let flags = entry.flags();
195                    debug_assert!(more(flags));
196                    self.notifier.clear().expect("cannot clear notifier");
197                }
198                _ => unsafe {
199                    create_entry(entry).notify();
200                },
201            }
202        }
203        has_entry
204    }
205
206    pub fn create_op<T: crate::sys::OpCode + 'static>(&self, op: T) -> Key<T> {
207        Key::new(self.as_raw_fd(), op)
208    }
209
210    pub fn attach(&mut self, _fd: RawFd) -> io::Result<()> {
211        Ok(())
212    }
213
214    pub fn cancel(&mut self, op: &mut Key<dyn crate::sys::OpCode>) {
215        instrument!(compio_log::Level::TRACE, "cancel", ?op);
216        trace!("cancel RawOp");
217        unsafe {
218            #[allow(clippy::useless_conversion)]
219            if self
220                .inner
221                .submission()
222                .push(
223                    &AsyncCancel::new(op.user_data() as _)
224                        .build()
225                        .user_data(Self::CANCEL)
226                        .into(),
227                )
228                .is_err()
229            {
230                warn!("could not push AsyncCancel entry");
231            }
232        }
233    }
234
235    fn push_raw(&mut self, entry: SEntry) -> io::Result<()> {
236        loop {
237            let mut squeue = self.inner.submission();
238            match unsafe { squeue.push(&entry) } {
239                Ok(()) => {
240                    squeue.sync();
241                    break Ok(());
242                }
243                Err(_) => {
244                    drop(squeue);
245                    self.poll_entries();
246                    match self.submit_auto(Some(Duration::ZERO)) {
247                        Ok(()) => {}
248                        Err(e)
249                            if matches!(
250                                e.kind(),
251                                io::ErrorKind::TimedOut | io::ErrorKind::Interrupted
252                            ) => {}
253                        Err(e) => return Err(e),
254                    }
255                }
256            }
257        }
258    }
259
260    pub fn push(&mut self, op: &mut Key<dyn crate::sys::OpCode>) -> Poll<io::Result<usize>> {
261        instrument!(compio_log::Level::TRACE, "push", ?op);
262        let user_data = op.user_data();
263        let op_pin = op.as_op_pin();
264        trace!("push RawOp");
265        match op_pin.create_entry() {
266            OpEntry::Submission(entry) => {
267                #[allow(clippy::useless_conversion)]
268                self.push_raw(entry.user_data(user_data as _).into())?;
269                Poll::Pending
270            }
271            #[cfg(feature = "io-uring-sqe128")]
272            OpEntry::Submission128(entry) => {
273                self.push_raw(entry.user_data(user_data as _))?;
274                Poll::Pending
275            }
276            OpEntry::Blocking => loop {
277                if self.push_blocking(user_data) {
278                    break Poll::Pending;
279                } else {
280                    self.poll_blocking();
281                }
282            },
283        }
284    }
285
286    fn push_blocking(&mut self, user_data: usize) -> bool {
287        let handle = self.handle();
288        let completed = self.pool_completed.clone();
289        self.pool
290            .dispatch(move || {
291                let mut op = unsafe { Key::<dyn crate::sys::OpCode>::new_unchecked(user_data) };
292                let op_pin = op.as_op_pin();
293                let res = op_pin.call_blocking();
294                completed.push(Entry::new(user_data, res));
295                handle.notify().ok();
296            })
297            .is_ok()
298    }
299
300    pub unsafe fn poll(&mut self, timeout: Option<Duration>) -> io::Result<()> {
301        instrument!(compio_log::Level::TRACE, "poll", ?timeout);
302        // Anyway we need to submit once, no matter there are entries in squeue.
303        trace!("start polling");
304
305        if !self.poll_entries() {
306            self.submit_auto(timeout)?;
307            self.poll_entries();
308        }
309
310        Ok(())
311    }
312
313    pub fn handle(&self) -> NotifyHandle {
314        self.notifier.handle()
315    }
316
317    #[cfg(io_uring)]
318    pub fn create_buffer_pool(
319        &mut self,
320        buffer_len: u16,
321        buffer_size: usize,
322    ) -> io::Result<BufferPool> {
323        let buffer_group = self.buffer_group_ids.insert(());
324        if buffer_group > u16::MAX as usize {
325            self.buffer_group_ids.remove(buffer_group);
326
327            return Err(io::Error::new(
328                io::ErrorKind::OutOfMemory,
329                "too many buffer pool allocated",
330            ));
331        }
332
333        let buf_ring = io_uring_buf_ring::IoUringBufRing::new(
334            &self.inner,
335            buffer_len,
336            buffer_group as _,
337            buffer_size,
338        )?;
339
340        #[cfg(fusion)]
341        {
342            Ok(BufferPool::new_io_uring(crate::IoUringBufferPool::new(
343                buf_ring,
344            )))
345        }
346        #[cfg(not(fusion))]
347        {
348            Ok(BufferPool::new(buf_ring))
349        }
350    }
351
352    #[cfg(not(io_uring))]
353    pub fn create_buffer_pool(
354        &mut self,
355        buffer_len: u16,
356        buffer_size: usize,
357    ) -> io::Result<BufferPool> {
358        Ok(BufferPool::new(buffer_len, buffer_size))
359    }
360
361    /// # Safety
362    ///
363    /// caller must make sure release the buffer pool with correct driver
364    #[cfg(io_uring)]
365    pub unsafe fn release_buffer_pool(&mut self, buffer_pool: BufferPool) -> io::Result<()> {
366        #[cfg(fusion)]
367        let buffer_pool = buffer_pool.into_io_uring();
368
369        let buffer_group = buffer_pool.buffer_group();
370        buffer_pool.into_inner().release(&self.inner)?;
371        self.buffer_group_ids.remove(buffer_group as _);
372
373        Ok(())
374    }
375
376    /// # Safety
377    ///
378    /// caller must make sure release the buffer pool with correct driver
379    #[cfg(not(io_uring))]
380    pub unsafe fn release_buffer_pool(&mut self, _: BufferPool) -> io::Result<()> {
381        Ok(())
382    }
383}
384
385impl AsRawFd for Driver {
386    fn as_raw_fd(&self) -> RawFd {
387        self.inner.as_raw_fd()
388    }
389}
390
391fn create_entry(cq_entry: CEntry) -> Entry {
392    let result = cq_entry.result();
393    let result = if result < 0 {
394        let result = if result == -libc::ECANCELED {
395            libc::ETIMEDOUT
396        } else {
397            -result
398        };
399        Err(io::Error::from_raw_os_error(result))
400    } else {
401        Ok(result as _)
402    };
403    let mut entry = Entry::new(cq_entry.user_data() as _, result);
404    entry.set_flags(cq_entry.flags());
405
406    entry
407}
408
409fn timespec(duration: std::time::Duration) -> Timespec {
410    Timespec::new()
411        .sec(duration.as_secs())
412        .nsec(duration.subsec_nanos())
413}
414
/// Cross-thread wakeup primitive backed by an `eventfd`.
/// The fd is shared via `Arc` with every [`NotifyHandle`].
#[derive(Debug)]
struct Notifier {
    fd: Arc<OwnedFd>,
}
419
420impl Notifier {
421    /// Create a new notifier.
422    fn new() -> io::Result<Self> {
423        let fd = syscall!(libc::eventfd(0, libc::EFD_CLOEXEC | libc::EFD_NONBLOCK))?;
424        let fd = unsafe { OwnedFd::from_raw_fd(fd) };
425        Ok(Self { fd: Arc::new(fd) })
426    }
427
428    pub fn clear(&self) -> io::Result<()> {
429        loop {
430            let mut buffer = [0u64];
431            let res = syscall!(libc::read(
432                self.fd.as_raw_fd(),
433                buffer.as_mut_ptr().cast(),
434                std::mem::size_of::<u64>()
435            ));
436            match res {
437                Ok(len) => {
438                    debug_assert_eq!(len, std::mem::size_of::<u64>() as _);
439                    break Ok(());
440                }
441                // Clear the next time:)
442                Err(e) if e.kind() == io::ErrorKind::WouldBlock => break Ok(()),
443                // Just like read_exact
444                Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
445                Err(e) => break Err(e),
446            }
447        }
448    }
449
450    pub fn handle(&self) -> NotifyHandle {
451        NotifyHandle::new(self.fd.clone())
452    }
453}
454
455impl AsRawFd for Notifier {
456    fn as_raw_fd(&self) -> RawFd {
457        self.fd.as_raw_fd()
458    }
459}
460
/// A notify handle to the inner driver.
pub struct NotifyHandle {
    // Shared eventfd; writing to it wakes the driver's poll loop.
    fd: Arc<OwnedFd>,
}
465
466impl NotifyHandle {
467    pub(crate) fn new(fd: Arc<OwnedFd>) -> Self {
468        Self { fd }
469    }
470
471    /// Notify the inner driver.
472    pub fn notify(&self) -> io::Result<()> {
473        let data = 1u64;
474        syscall!(libc::write(
475            self.fd.as_raw_fd(),
476            &data as *const _ as *const _,
477            std::mem::size_of::<u64>(),
478        ))?;
479        Ok(())
480    }
481}